"""
QUADRATIC layers
Here we define linear and quadratic layers that share subunits
"""

import tensorflow.compat.v1 as tf
import numpy as np
#tf.compat.v1.disable_eager_execution()

class Dense(object):
    """
    Dense linear unit
    This layer implements a affine transformation followed by pointwise
    nonlinearity. It also allows to define a self normalized version where
    each of the weight vectors that combine linearly with the input coefficient
    have unit norm.

    
    """
    def __init__(self, in_dim,
                 out_dim,
                 activation=None,
                 usebias=True,
                 weightnorm=False):
        """
        Inputs:
        -- in_dim : integer denoting the number of input dimensions 
        -- out_dim : integer denotign the number of output dimensions
        -- activation : (default None) pointwise nonlinearity function
        -- usebias : (optional) use bias in the affine transformation
        default value is True.
        -- weightnorm : (optional) normalizes the columns of the weighmatrix to 
        have unit norm
        
        """
        self.in_dim = in_dim
        self.n_units = out_dim
        self.usebias = usebias
        # initialize the weights with standard deviation 1/sqrt(in_dim)
        W_shape = [self.n_units, self.in_dim]
        W_std = 1.0 / np.sqrt(self.in_dim)
        self.W = tf.Variable(tf.random_normal(W_shape, stddev=W_std),
                             dtype=tf.float32)
        if self.usebias:
            self.bias = tf.Variable(tf.zeros((self.n_units, 1), dtype=tf.float32))
        self.activation = activation
        self.weightnorm = weightnorm

    def propagateForward(self, input):
        """
        Maps the input array to an output array with the sequence of operations
        defined in the layer
        
        Inputs:  
        -- input : an d-dimensional tensor where where the first axis is the
        the number of samples
        """ 
        if self.weightnorm:
            # normalize each weight vector (row of W) to unit norm before
            # applying the linear map; axis=1 selects the per-unit vectors
            W_norm = tf.divide(self.W, tf.norm(self.W, axis=1, keepdims=True))
            WX = tf.matmul(input, tf.transpose(W_norm))
        else:
            WX = tf.matmul(input, tf.transpose(self.W))
        if self.usebias:
            lin_term = WX + tf.transpose(self.bias)
        else:
            lin_term = WX
        if self.activation is None:
            return lin_term
        else:
            return self.activation(lin_term)
        
    def __call__(self, input):
        return self.propagateForward(input)
        
    def get_trainable_params(self):
        """
        Returns a tuple with parameters that can be tuned.
        This function is useful to manage list of paramters we want to tune or
        keep fixed during learning.
        """
        if self.usebias:
            return (self.W, self.bias)
        else:
            return (self.W, )


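# Example usage (a minimal sketch, assuming the TF2 compat.v1 API; eager
# execution is disabled here to get graph/session semantics, and the shapes
# and random data below are illustrative only):
if __name__ == "__main__":
    tf.disable_eager_execution()
    layer = Dense(in_dim=3, out_dim=2, activation=tf.nn.relu,
                  weightnorm=True)
    x = tf.placeholder(tf.float32, shape=[None, 3])  # batch of 3-dim inputs
    y = layer(x)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(y, feed_dict={x: np.random.randn(5, 3).astype(np.float32)})
        print(out.shape)  # expected: (5, 2)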